import numpy as np
import cv2
import os
import time
from PIL import Image
import io
import copy
import torch
from models.experimental import attempt_load
from utils.torch_utils import select_device
from utils.general import (
    check_img_size, non_max_suppression, apply_classifier, scale_coords,
    xyxy2xywh, xywh2xyxy, strip_optimizer)
from torchvision import transforms
# import random
import fall_TSM_Module
import random
from torch.nn import functional as F

# Detector configuration and module-level state.
# NOTE: everything below runs at import time, including loading the checkpoint.
weights = 'weights/yolov5x.pt'  # checkpoint path handed to attempt_load below
imgsize = 640  # size passed to check_img_size here and to letterbox in fallDetection
confthres = 0.4  # confidence threshold passed to non_max_suppression
iouthres = 0.5  # IoU threshold passed to non_max_suppression
frame_number = 8  # number of analysed frames kept in the sliding window
# Hard-coded class-name list; it is overwritten a few lines below by the names
# stored on the loaded model, so this literal only documents the expected classes.
names = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light',
         'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
         'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
         'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
         'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
         'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
         'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
         'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear',
         'hair drier', 'toothbrush']

# Initialize
device = select_device('0')  # presumably selects CUDA device 0 — confirm against utils.torch_utils
half = device.type != 'cpu'  # run in FP16 whenever the selected device is not the CPU
model = attempt_load(weights, map_location=device)
# Prefer the class names stored on the model (unwrapping a `.module` wrapper if present).
names = model.module.names if hasattr(model, 'module') else model.names
if half:
    model.half()  # to FP16
# NOTE(review): imgsz is not referenced anywhere else in the visible file;
# fallDetection passes `imgsize` to letterbox instead — confirm this is intended.
imgsz = check_img_size(imgsize, s=model.stride.max())
# Sliding-window buffers: fallDetection appends one entry per analysed frame to
# each list, pops the oldest once the window is full, and clears them all after
# a proposal has been evaluated.
# image data list (copies of the analysed frames)
imagedatalist = []
# result list (per frame: kept person boxes; per frame: [pose name, probability] pairs)
resultlist = []
resultPoseList = []
# detected index (per frame: 1 if a person box was kept, else 0; per frame: pose severity 0/1/2)
detected = []
posedetected = []
# Set by fallDetection when the window qualifies for TSM confirmation; reset afterwards.
proposalFlag = False

# Pose classifier setup (also executed at import time).
from efficientnet_pytorch import EfficientNet

# Legacy two-class labelling, kept for reference:
# classnames = ["abnormal"->0, "normal"->1]
# Labels looked up by the class index predicted by posemodel in fallDetection.
poseNames = ["bend", "fall", "jump", "lie", "ride", "run", "sit", "squat", "stand", "walk"]
# TODO 2 classes versus 10 classes
# posemodel = EfficientNet.from_pretrained('efficientnet-b5',
#                                          weights_path='weights/pose2/pose.best.pth.tar',
#                                          num_classes=2, load_fc=True)
# num_classes=10 matches the ten entries of poseNames above.
posemodel = EfficientNet.from_pretrained('efficientnet-b5',
                                         weights_path='weights/pose10extra/pose.best.pth.tar',
                                         num_classes=10, load_fc=True)
posemodel.to(device)
posemodel.eval()
if half:
    posemodel.half()  # to FP16

# transform after crop: resize each person crop to 224x224 and convert it to a tensor
tfms = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
# Running counters incremented in fallDetection and reported by main:
# proposals TSM labelled 'fall' / proposals TSM labelled anything else.
detectFall = 0
detectOther = 0

def letterbox(img, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True):
    """Resize ``img`` keeping its aspect ratio and pad it with a constant border.

    Args:
        img: image array whose first two dimensions are (height, width).
        new_shape: target (height, width); a single int means a square target.
        color: border colour handed to ``cv2.copyMakeBorder``.
        auto: if true, the padding on each axis is reduced modulo 64 so only
            the minimal border is added.
        scaleFill: if true (and ``auto`` is false), stretch the image to
            exactly ``new_shape`` with no padding.
        scaleup: if false, the image is only ever shrunk, never enlarged.

    Returns:
        ``(padded_image, (width_ratio, height_ratio), (pad_w, pad_h))`` where
        the pad values are the per-side padding (already halved).
    """
    src_h, src_w = img.shape[:2]
    if isinstance(new_shape, int):
        new_shape = (new_shape, new_shape)
    dst_h, dst_w = new_shape[0], new_shape[1]

    # Uniform scale factor (new / old) that fits the image inside the target.
    scale = min(dst_h / src_h, dst_w / src_w)
    if not scaleup:
        scale = min(scale, 1.0)

    ratio = (scale, scale)  # width, height ratios
    new_unpad = (int(round(src_w * scale)), int(round(src_h * scale)))  # (w, h)
    pad_w = dst_w - new_unpad[0]
    pad_h = dst_h - new_unpad[1]
    if auto:
        # Minimum rectangle: keep only the remainder modulo 64 as padding.
        pad_w, pad_h = np.mod(pad_w, 64), np.mod(pad_h, 64)
    elif scaleFill:
        # Stretch to the full target; the two axes get independent ratios.
        pad_w, pad_h = 0.0, 0.0
        new_unpad = (dst_w, dst_h)
        ratio = (dst_w / src_w, dst_h / src_h)

    # Split the padding between the two opposite sides.
    pad_w /= 2
    pad_h /= 2

    if (src_w, src_h) != new_unpad:
        img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)

    # The +/-0.1 nudges send an odd leftover pixel to the bottom/right side.
    top = int(round(pad_h - 0.1))
    bottom = int(round(pad_h + 0.1))
    left = int(round(pad_w - 0.1))
    right = int(round(pad_w + 0.1))
    img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)
    return img, ratio, (pad_w, pad_h)


def plot_one_box(x, img, color=None, label=None, line_thickness=None):
    """Draw one bounding box, and optionally a filled label tag, onto ``img`` in place.

    Args:
        x: box as an indexable of four values (x1, y1, x2, y2).
        img: image array drawn on in place.
        color: box colour; a random colour is picked when omitted.
        label: optional text drawn in a filled tag at the box's top-left corner.
        line_thickness: line thickness; derived from the image size when omitted.
    """
    # Line/font thickness scales with the mean of the image height and width.
    thickness = line_thickness or round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1
    if not color:
        color = [random.randint(0, 255) for _ in range(3)]

    top_left = (int(x[0]), int(x[1]))
    bottom_right = (int(x[2]), int(x[3]))
    cv2.rectangle(img, top_left, bottom_right, color, thickness=thickness, lineType=cv2.LINE_AA)

    if not label:
        return

    font_thickness = max(thickness - 1, 1)
    font_scale = thickness / 3
    text_w, text_h = cv2.getTextSize(label, 0, fontScale=font_scale, thickness=font_thickness)[0]
    # The tag sits directly above the top-left corner of the box.
    tag_corner = (top_left[0] + text_w, top_left[1] - text_h - 3)
    cv2.rectangle(img, top_left, tag_corner, color, -1, cv2.LINE_AA)  # filled
    cv2.putText(img, label, (top_left[0], top_left[1] - 2), 0, font_scale,
                [225, 255, 255], thickness=font_thickness, lineType=cv2.LINE_AA)


def plot_action(img, color=None, line_thickness=None, action=None):
    """Draw the action label as a filled tag anchored at pixel (100, 100) of ``img``.

    Args:
        img: image array drawn on in place.
        color: tag background colour; a random colour is picked when omitted.
        line_thickness: line/font thickness; derived from the image size when omitted.
        action: text to draw. When ``None`` (the default) nothing is drawn.
    """
    # The default ``action=None`` used to be passed straight to cv2.getTextSize,
    # which cannot measure a non-string; treat "no action" as "nothing to draw".
    if action is None:
        return

    tl = line_thickness or round(
        0.002 * (img.shape[0] + img.shape[1]) / 2) + 1  # line/font thickness
    color = color or [random.randint(0, 255) for _ in range(3)]
    tf = max(tl - 1, 1)  # font thickness
    text_w, text_h = cv2.getTextSize(action, 0, fontScale=tl / 3, thickness=tf)[0]
    anchor = (100, 100)  # fixed bottom-left corner of the tag
    tag_corner = (anchor[0] + text_w, anchor[1] - text_h - 3)
    cv2.rectangle(img, anchor, tag_corner, color, -1, cv2.LINE_AA)  # filled
    cv2.putText(img, action, (anchor[0], anchor[1] - 2), 0, tl / 3,
                [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)


def weightedPose(poseList):
    """Weight every pose score by the sum of its neighbours' scores.

    For each position the neighbours are the entries at most two places to the
    left or right (clipped at the list boundaries), excluding the entry itself.
    The result at that position is ``neighbour_sum * own_score``, so an
    isolated non-zero score contributes nothing.

    Args:
        poseList: sequence of numeric per-frame pose scores.

    Returns:
        A new list with the same length as ``poseList``.
    """
    count = len(poseList)
    weighted = []
    for idx, score in enumerate(poseList):
        lo = max(0, idx - 2)
        hi = min(count - 1, idx + 2)
        neighbour_sum = sum(poseList[lo:hi + 1]) - score
        weighted.append(neighbour_sum * score)
    return weighted


def fallDetection(img0):
    """Analyse one frame and, once the window qualifies, confirm a fall with TSM.

    Steps performed on every call:
      1. Run the detector on ``img0`` and keep boxes of class 0 that are at
         least 100 px tall.
      2. Classify each kept crop with ``posemodel``.
      3. Push the frame copy, boxes, poses and per-frame flags into the
         module-level sliding window (at most ``frame_number`` entries).
      4. When the window is full and the pose history qualifies, call
         ``fall_TSM_Module.alertAction`` on the buffered frames. If it answers
         ``'fall'`` the annotated frames are written to ``data/falltest/all``;
         either way the window is cleared afterwards.

    Args:
        img0: frame array as passed to ``cv2.cvtColor(..., COLOR_BGR2RGB)``,
            i.e. presumably a BGR image as read by OpenCV.

    Returns:
        None. Mutates the module-level window lists and the globals
        ``detectFall``, ``detectOther`` and ``proposalFlag``.
    """
    global detectFall, detectOther, proposalFlag
    img_org = copy.deepcopy(img0)  # private copy stored in the window and drawn on later

    # Detector input: letterbox, BGR -> RGB, HWC -> CHW, then scale to 0.0 - 1.0.
    img = letterbox(img0, new_shape=imgsize)[0]
    img = img[:, :, ::-1].transpose(2, 0, 1)
    img = np.ascontiguousarray(img)
    img = torch.from_numpy(img).to(device)
    img = img.half() if half else img.float()  # uint8 to fp16/32
    img /= 255.0
    if img.ndimension() == 3:
        img = img.unsqueeze(0)

    # PIL copy of the frame, used for cropping the person boxes.
    pilimg = Image.fromarray(cv2.cvtColor(img0, cv2.COLOR_BGR2RGB))

    # Detection. Inference only, so no autograd state is recorded
    # (the pose model below is already evaluated the same way).
    with torch.no_grad():
        detections = model(img, augment=False)[0]
        detections = non_max_suppression(detections, confthres, iouthres, classes=None, agnostic=False)

    result = []          # kept boxes as [x1, y1, x2, y2]
    poseimglist = []     # transformed crops, index-aligned with `result`
    poseOutputList = []  # [pose name, probability] pairs, index-aligned with `result`

    for det in detections:
        if det is None or not len(det):
            continue
        # Map the boxes from the letterboxed input back to the original frame.
        det[:, :4] = scale_coords(img.shape[2:], det[:, :4], img0.shape).round()
        for *xyxyTensor, conf, cls in reversed(det):
            xyxy = [coord.item() for coord in xyxyTensor]
            boxHeight = xyxy[3] - xyxy[1]
            # Keep only class 0 boxes that are at least 100 px tall.
            if cls == 0 and boxHeight >= 100:
                cropped = pilimg.crop(tuple(xyxy))
                poseimglist.append(tfms(cropped))
                result.append(xyxy)

    # Pose classification of all kept crops in a single batch.
    # poseNames = ["bend", "fall", "jump", "lie", "ride", "run", "sit", "squat", "stand", "walk"]
    if poseimglist:
        posemodelinput = torch.stack(poseimglist, dim=0).to(device)
        posemodelinput = posemodelinput.half() if half else posemodelinput.float()
        with torch.no_grad():
            probabilities = F.softmax(posemodel(posemodelinput), dim=1)
            topProb, topIndex = probabilities.topk(1, 1, True, True)
        for poseProb, poseIndex in zip(topProb, topIndex):
            poseOutputList.append([poseNames[poseIndex.item()], poseProb.item()])

    # Sliding window: drop the oldest entry of every list once the window is full.
    if len(detected) >= frame_number:
        imagedatalist.pop(0)
        resultlist.pop(0)
        resultPoseList.pop(0)
        detected.pop(0)
        posedetected.pop(0)

    imagedatalist.append(img_org)
    resultlist.append(result)
    resultPoseList.append(poseOutputList)
    detected.append(1 if result else 0)

    # Abnormality judgement: per-frame severity is 2 for fall/lie,
    # 1 for squat/sit, and 0 for anything else or when no person was kept.
    posesInFrame = {pose[0] for pose in poseOutputList}
    if "fall" in posesInFrame or "lie" in posesInFrame:
        posedetected.append(2)
    elif "squat" in posesInFrame or "sit" in posesInFrame:
        posedetected.append(1)
    else:
        posedetected.append(0)

    if len(posedetected) == frame_number:
        # Proposal condition: at least one severity-2 frame in the window and a
        # sufficiently high neighbour-weighted severity sum.
        # NOTE(review): the original comment described the threshold as
        # "greater than 4, i.e. better than [0, 0, 0, 0, 0, 0, 1, 2]", while the
        # code requires a sum >= 8 — confirm which value is intended.
        newPose = weightedPose(posedetected)
        proposalFlag = 2 in posedetected and sum(newPose) >= 8

    # TODO filter condition
    if proposalFlag:
        # TSM confirmation on the buffered frames; the result is unpacked as
        # (predicted label, confidence).
        tsmPred, tsmConf = fall_TSM_Module.alertAction(imagedatalist)
        actionLabel = f'{tsmPred} {tsmConf:.2f}%'
        if tsmPred == 'fall':
            detectFall += 1
            saveDir = 'data/falltest/all'
            # cv2.imwrite does not create missing directories, so make sure it exists.
            os.makedirs(saveDir, exist_ok=True)
            for index, fallFrame in enumerate(imagedatalist):
                for resultBox, resultPose in zip(resultlist[index], resultPoseList[index]):
                    # Label: pose name, probability, box height in pixels.
                    label = f'{resultPose[0]} {resultPose[1]:.2f} {resultBox[3] - resultBox[1]}'
                    plot_one_box(resultBox, fallFrame, label=label)
                plot_action(fallFrame, action=actionLabel)
                savepath = os.path.join(saveDir, '{0}_{1}_{2}_detected{3}_pose{4}.jpg'.format(
                    tsmPred, detectFall, index, detected[index], posedetected[index]))
                cv2.imwrite(savepath, fallFrame)
        else:
            # Non-fall proposals are only counted; their frames are not saved.
            detectOther += 1

        # Start a fresh window after every evaluated proposal.
        resultlist.clear()
        resultPoseList.clear()
        imagedatalist.clear()
        detected.clear()
        posedetected.clear()
        proposalFlag = False


def _clear_output_dir(directory):
    """Best-effort removal of the non-hidden entries of ``directory``; creates it if missing."""
    import glob
    import shutil

    os.makedirs(directory, exist_ok=True)
    # Like the shell pattern `dir/*`, glob skips entries whose name starts with a dot.
    for entry in glob.glob(os.path.join(directory, '*')):
        try:
            if os.path.isdir(entry) and not os.path.islink(entry):
                shutil.rmtree(entry)
            else:
                os.remove(entry)
        except OSError as err:
            # Stale output must not prevent the run from starting.
            print("Could not remove", entry, ":", err)


def main():
    """Process the input video, running fallDetection on every 6th frame, and print a summary.

    Clears ``./data/falltest/all`` first, then reads ``data/input/luoyan.mp4``
    frame by frame until the video ends, a read fails, or ``q`` is reported by
    ``cv2.waitKey``. Finally prints the elapsed time and the proposal counters.
    """
    _clear_output_dir('./data/falltest/all')
    videoPath = 'data/input/luoyan.mp4'
    cap = cv2.VideoCapture(videoPath)
    if not cap.isOpened():
        print("Could not open video source:", videoPath)
    fNUMS = cap.get(cv2.CAP_PROP_FRAME_COUNT)
    fps = cap.get(cv2.CAP_PROP_FPS)
    print("FPS is ", fps)

    frameCount = 0
    print("Starting...")
    start = time.time()
    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            frameCount += 1
            # Only every 6th frame is analysed.
            if frameCount % 6 == 0:
                fallDetection(frame)
            # NOTE(review): no window is created in the visible code, so this key
            # poll presumably never sees 'q' — confirm whether it is still needed.
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
            if frameCount % 1000 == 0:
                if fNUMS > 0:
                    print("Current :", round(frameCount / float(fNUMS) * 100, 1), "%")
                else:
                    # The source reports no total frame count, so a percentage
                    # cannot be computed (and would divide by zero).
                    print("Current :", frameCount, "frames")
    finally:
        # Release the capture even if frame processing raises.
        cap.release()
    end = time.time()
    print("all done")
    print("processing time : {0:.1f}, total : {1}, fall : {2}".format(
        float(end - start), detectFall + detectOther, detectFall))


# Run the video-processing entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()
